{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b5789bc3-b1ae-42c7-94a8-2ef4f89946fc",
   "metadata": {},
   "source": [
    "# Lesson 4: Persistence and Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5762271-8736-4e94-9444-8c92bd0e8074",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "\n",
    "_ = load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0168aee-bce9-4d60-b827-f86a88187e31",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "from langgraph.graph import StateGraph, END\n",
    "from typing import TypedDict, Annotated\n",
    "import operator\n",
    "from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_community.tools.tavily_search import TavilySearchResults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da06a64f-a2d5-4a66-8090-9ada0930c684",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "tool = TavilySearchResults(max_results=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2589c5b6-6cc2-4594-9a17-dccdcf676054",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "class AgentState(TypedDict):\n",
    "    messages: Annotated[list[AnyMessage], operator.add]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c033522-d2fc-41ac-8e3c-5e35872bf88d",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "from langgraph.checkpoint.sqlite import SqliteSaver\n",
    "\n",
    "memory = SqliteSaver.from_conn_string(\":memory:\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2ba84ec-c172-4de7-ac55-e3158a531b23",
   "metadata": {
    "height": 574
   },
   "outputs": [],
   "source": [
    "class Agent:\n",
    "    def __init__(self, model, tools, checkpointer, system=\"\"):\n",
    "        self.system = system\n",
    "        graph = StateGraph(AgentState)\n",
    "        graph.add_node(\"llm\", self.call_openai)\n",
    "        graph.add_node(\"action\", self.take_action)\n",
    "        graph.add_conditional_edges(\"llm\", self.exists_action, {True: \"action\", False: END})\n",
    "        graph.add_edge(\"action\", \"llm\")\n",
    "        graph.set_entry_point(\"llm\")\n",
    "        self.graph = graph.compile(checkpointer=checkpointer)\n",
    "        self.tools = {t.name: t for t in tools}\n",
    "        self.model = model.bind_tools(tools)\n",
    "\n",
    "    def call_openai(self, state: AgentState):\n",
    "        messages = state['messages']\n",
    "        if self.system:\n",
    "            messages = [SystemMessage(content=self.system)] + messages\n",
    "        message = self.model.invoke(messages)\n",
    "        return {'messages': [message]}\n",
    "\n",
    "    def exists_action(self, state: AgentState):\n",
    "        result = state['messages'][-1]\n",
    "        return len(result.tool_calls) > 0\n",
    "\n",
    "    def take_action(self, state: AgentState):\n",
    "        tool_calls = state['messages'][-1].tool_calls\n",
    "        results = []\n",
    "        for t in tool_calls:\n",
    "            print(f\"Calling: {t}\")\n",
    "            result = self.tools[t['name']].invoke(t['args'])\n",
    "            results.append(ToolMessage(tool_call_id=t['id'], name=t['name'], content=str(result)))\n",
    "        print(\"Back to the model!\")\n",
    "        return {'messages': results}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "876d5092-b8ef-4e38-b4d7-0e80c609bf7a",
   "metadata": {
    "height": 132
   },
   "outputs": [],
   "source": [
    "prompt = \"\"\"You are a smart research assistant. Use the search engine to look up information. \\\n",
    "You are allowed to make multiple calls (either together or in sequence). \\\n",
    "Only look up information when you are sure of what you want. \\\n",
    "If you need to look up some information before asking a follow up question, you are allowed to do that!\n",
    "\"\"\"\n",
    "model = ChatOpenAI(model=\"gpt-4o\")\n",
    "abot = Agent(model, [tool], system=prompt, checkpointer=memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10084a02-2928-4945-9f7c-ad3f5b33caf7",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "messages = [HumanMessage(content=\"What is the weather in sf?\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "714d1205-f8fc-4912-b148-2a45da99219c",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "thread = {\"configurable\": {\"thread_id\": \"1\"}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83588e70-254f-4f83-a510-c8ae81e729b0",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "for event in abot.graph.stream({\"messages\": messages}, thread):\n",
    "    for v in event.values():\n",
    "        print(v['messages'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cb3ef4c-58b3-401b-b104-0d51e553d982",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "messages = [HumanMessage(content=\"What about in la?\")]\n",
    "thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
    "for event in abot.graph.stream({\"messages\": messages}, thread):\n",
    "    for v in event.values():\n",
    "        print(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc3293b7-a50c-43c8-a022-8975e1e444b8",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "messages = [HumanMessage(content=\"Which one is warmer?\")]\n",
    "thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
    "for event in abot.graph.stream({\"messages\": messages}, thread):\n",
    "    for v in event.values():\n",
    "        print(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0722c3d4-4cbf-43bf-81b0-50f634c4ce61",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "messages = [HumanMessage(content=\"Which one is warmer?\")]\n",
    "thread = {\"configurable\": {\"thread_id\": \"2\"}}\n",
    "for event in abot.graph.stream({\"messages\": messages}, thread):\n",
    "    for v in event.values():\n",
    "        print(v)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ace59a36-3941-459e-b9d1-ac5a4a1ed3ae",
   "metadata": {},
   "source": [
    "## Streaming tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2f82fe-3ec4-4917-be51-9fb10d1317fa",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver\n",
    "\n",
    "memory = AsyncSqliteSaver.from_conn_string(\":memory:\")\n",
    "abot = Agent(model, [tool], system=prompt, checkpointer=memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0fe1c7-77e2-499c-a2f9-1f739bb6ddf0",
   "metadata": {
    "height": 200
   },
   "outputs": [],
   "source": [
    "messages = [HumanMessage(content=\"What is the weather in SF?\")]\n",
    "thread = {\"configurable\": {\"thread_id\": \"4\"}}\n",
    "async for event in abot.graph.astream_events({\"messages\": messages}, thread, version=\"v1\"):\n",
    "    kind = event[\"event\"]\n",
    "    if kind == \"on_chat_model_stream\":\n",
    "        content = event[\"data\"][\"chunk\"].content\n",
    "        if content:\n",
    "            # Empty content in the context of OpenAI means\n",
    "            # that the model is asking for a tool to be invoked.\n",
    "            # So we only print non-empty content\n",
    "            print(content, end=\"|\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98f303b1-a4d0-408c-8cc0-515ff980717f",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
